# ---
# cmd: ["modal", "run", "06_gpu_and_ml/reinforcement-learning/grpo_trl.py::train"]
# ---

# # Train a model to solve coding problems using GRPO and TRL

# This example demonstrates how to run [GRPO](https://arxiv.org/pdf/2402.03300) on Modal using the TRL [GRPO trainer](https://huggingface.co/docs/trl/main/en/grpo_trainer).
# GRPO is a reinforcement learning algorithm introduced by DeepSeek and used to train DeepSeek-R1.
# TRL is a reinforcement learning training library from Hugging Face.

# First, we import the required modules and define the app.
from __future__ import annotations

import os
import re
import subprocess
from pathlib import Path
from typing import Iterable, Sequence

import modal

app: modal.App = modal.App("example-grpo-trl")

# We define an image that installs the TRL library, along with vLLM (used later in this example)
# and Weights & Biases for logging.
image: modal.Image = modal.Image.debian_slim().uv_pip_install(
    "trl[vllm]==0.19.1", "datasets==3.5.1", "wandb==0.17.6"
)

# We import the necessary libraries needed in the context of the image.
with image.imports():
    from datasets import Dataset, load_dataset
    from trl import GRPOConfig, GRPOTrainer

# We also define a [Modal Volume](https://modal.com/docs/guide/volumes#volumes) for storing model checkpoints.
MODELS_DIR = Path("/models")
checkpoints_volume: modal.Volume = modal.Volume.from_name(
    "example-grpo-trl-checkpoints", create_if_missing=True
)

# ## Defining the reward function

# In this example, we use the [OpenCoder-LLM/opc-sft-stage2](https://huggingface.co/datasets/OpenCoder-LLM/opc-sft-stage2) dataset to train a model to solve coding problems.

# In reinforcement learning, we define a reward function for the model. Since we are evaluating code that is generated by
# a model, we use [Modal Sandboxes](https://modal.com/docs/guide/sandbox) to evaluate the code securely.


# For each completion from the model and its accompanying test cases, we define a simple reward function:
# it returns 1 if the code runs without errors and 0 otherwise. You will likely want to adjust this reward,
# since the model is unlikely to learn well from such a sparse, binary signal.


@app.function()
def compute_reward(completion: str, testcase: Sequence[str]) -> int:
    score = 0
    sb: modal.Sandbox = modal.Sandbox.create(app=app)
    code_to_execute: str = get_generated_code_and_test_cases(completion, testcase)

    try:
        p = sb.exec("python", "-c", code_to_execute, timeout=30)
        p.wait()
        if p.returncode == 0:  # the code ran and all assertions passed
            score = 1
    except Exception as e:
        print(e)
    finally:
        sb.terminate()
    return score
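
# For example, a denser reward might grant partial credit per passing test case.
# A minimal sketch (the `compute_partial_reward` name is hypothetical; it reuses the sandbox pattern above):

# ```python
# @app.function()
# def compute_partial_reward(completion: str, testcase: Sequence[str]) -> float:
#     code = get_generated_code_and_test_cases(completion, [])  # extract the code, no tests
#     sb = modal.Sandbox.create(app=app)
#     passed = 0
#     try:
#         for test in testcase:  # run each assert in isolation for partial credit
#             p = sb.exec("python", "-c", f"{code}\n{test}", timeout=30)
#             p.wait()
#             if p.returncode == 0:
#                 passed += 1
#     except Exception as e:
#         print(e)
#     finally:
#         sb.terminate()
#     return passed / len(testcase) if testcase else 0.0
# ```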


# We write a helper that constructs a runnable program from a model completion. Its logic is determined by the format of the data:
# completions are expected to wrap their code in a "```python ... ```" block,
# and the test cases are a list of assert statements.
# More details [here](https://huggingface.co/datasets/OpenCoder-LLM/opc-sft-stage2).
def get_generated_code_and_test_cases(completion: str, testcase: Sequence[str]) -> str:
    if "```python" in completion:
        # Find the start and end of the code block
        start_idx: int = completion.find("```python") + len("```python")
        end_idx: int = completion.find("```", start_idx)
        if end_idx != -1:
            code: str = completion[start_idx:end_idx].strip()
        else:
            code: str = completion[start_idx:].strip()
    else:
        code: str = completion.strip()

    test_cases: str = "\n".join(testcase)
    full_code: str = f"{code}\n\n{test_cases}"
    return full_code
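
# For example, a completion that wraps a small function in a Python code block, paired with
# a single assert, yields the function definition followed by the test:

# ```python
# program = get_generated_code_and_test_cases(
#     '```python\ndef add(a, b):\n    return a + b\n```',
#     ["assert add(1, 2) == 3"],
# )
# print(program)
# # def add(a, b):
# #     return a + b
# #
# # assert add(1, 2) == 3
# ```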


# Finally, we define the reward function that is passed into the `GRPOTrainer`. It takes in a batch of completions
# (plus the matching `testcases` column) and fans the reward computation out across Modal containers via `starmap`.
# Custom reward functions must conform to a [specific signature](https://huggingface.co/docs/trl/main/en/grpo_trainer#using-a-custom-reward-function).
def reward_helper_function(
    completions: Sequence[str], testcases: Sequence[Sequence[str]], **kwargs: object
) -> Iterable[int]:
    return compute_reward.starmap(zip(completions, testcases))
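
# To sanity-check the reward pipeline before a full training run, you can invoke it under an
# ephemeral app from a local Python shell (a quick sketch with a toy completion):

# ```python
# with modal.enable_output(), app.run():
#     rewards = reward_helper_function(
#         completions=['```python\ndef add(a, b):\n    return a + b\n```'],
#         testcases=[["assert add(1, 2) == 3"]],
#     )
#     print(list(rewards))  # [1]
# ```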


# ## Kicking off a training run


# Preprocess the data, preparing the columns that `GRPOTrainer` expects.
# We use the OpenCoder-LLM educational instruct dataset, which has (instruction, code, test case) triples validated through a Python compiler.
# More details [here](https://huggingface.co/datasets/OpenCoder-LLM/opc-sft-stage2).
def start_grpo_trainer(use_vllm: bool = False, vllm_mode: str | None = None) -> None:
    dataset: Dataset = load_dataset(
        "OpenCoder-LLM/opc-sft-stage2", "educational_instruct", split="train"
    )
    dataset = dataset.rename_column(
        "instruction", "prompt"
    )  # Needed for the GRPO trainer
    dataset = dataset.rename_column("testcase", "testcases")
    dataset = dataset.select(range(128))  # To simplify testing.
    training_args: GRPOConfig = GRPOConfig(
        output_dir=str(MODELS_DIR),
        report_to="wandb",
        use_vllm=use_vllm,
        vllm_mode=vllm_mode,
        save_steps=1,
        max_steps=5,  # To simplify testing. Remove for production use cases.
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=reward_helper_function,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
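

# Since checkpoints land in the Volume, an interrupted run can be resumed via the standard
# `transformers` resume mechanism (a sketch; swap this in for the `trainer.train()` call above):

# ```python
# trainer.train(resume_from_checkpoint=True)
# ```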


# We use Weights & Biases for logging, so we attach a [Modal Secret](https://modal.com/docs/guide/secrets#secrets) with wandb credentials.
@app.function(
    image=image,
    gpu="H100",
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    volumes={"/models": checkpoints_volume},
)
def train() -> None:
    start_grpo_trainer()


# To run: `modal run --detach grpo_trl.py::train`.

# ## Speeding up training with vLLM


# vLLM can be used either in server mode (vLLM runs as a separate process, on separate GPUs) or colocate mode (vLLM runs inside the training process).
# In server mode, the trainer communicates with the vLLM server over HTTP.
# This is ideal if you have dedicated GPUs for inference. More details [here](https://huggingface.co/docs/trl/main/en/grpo_trainer#-option-1-server-mode).
# Here, we request 2 GPUs: we run the `GRPOTrainer` on one of them and the vLLM process on the other.
@app.function(
    image=image,
    gpu="H100:2",
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    volumes={str(MODELS_DIR): checkpoints_volume},
)
def train_vllm_server_mode() -> None:
    env_copy = os.environ.copy()
    env_copy["CUDA_VISIBLE_DEVICES"] = "0"  # Run serve vLLM process on GPU 0

    # Start vllm-serve in the background
    subprocess.Popen(
        ["trl", "vllm-serve", "--model", "Qwen/Qwen2-0.5B-Instruct"],
        env=env_copy,
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # Run training process on GPU 1
    start_grpo_trainer(use_vllm=True, vllm_mode="server")


# You can execute this using `modal run --detach grpo_trl.py::train_vllm_server_mode`.

# In colocate mode, vLLM runs inside the trainer process and shares GPU memory with the training model.
# This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
# More details [here](https://huggingface.co/docs/trl/main/en/grpo_trainer#-option-2-colocate-mode).


@app.function(
    image=image,
    gpu="H100",
    timeout=60 * 60 * 24,  # 24 hours
    secrets=[modal.Secret.from_name("wandb-secret")],
    volumes={"/models": checkpoints_volume},
)
def train_vllm_colocate_mode() -> None:
    # Rank of the current process (0 for single-process training)
    os.environ["RANK"] = "0"
    # Local rank of the process on the node (0 for single-process training)
    os.environ["LOCAL_RANK"] = "0"
    # Total number of processes (1 for single-process training)
    os.environ["WORLD_SIZE"] = "1"
    # Address of the master node (localhost for single node)
    os.environ["MASTER_ADDR"] = "localhost"
    # Port for communication between processes
    os.environ["MASTER_PORT"] = "12355"
    start_grpo_trainer(use_vllm=True, vllm_mode="colocate")


# You can execute this using `modal run --detach grpo_trl.py::train_vllm_colocate_mode`.

# ## Performing inference on the trained model

# We use vLLM to perform inference on the trained model.

VLLM_PORT: int = 8000


# Once you have the model checkpoints in your Modal Volume, you can load the weights and perform inference using vLLM. For more on storing model weights on Modal, see
# [this guide](https://modal.com/docs/guide/model-weights).
# The trainer writes checkpoints into the output directory as `checkpoint-n`, where n is the training step at which the checkpoint was saved (e.g. `checkpoint-5`).
# The helper below scans the Volume for these directories and returns the path to the latest one.
def get_latest_checkpoint_file_path() -> str:
    checkpoint_dirs = [
        d.name
        for d in MODELS_DIR.iterdir()
        if d.is_dir() and re.match(r"^checkpoint-(\d+)$", d.name)
    ]
    if not checkpoint_dirs:
        raise FileNotFoundError("No checkpoint directories found in models dir")
    latest_checkpoint_index = max(
        int(re.match(r"^checkpoint-(\d+)$", d).group(1)) for d in checkpoint_dirs
    )
    return str(MODELS_DIR / f"checkpoint-{latest_checkpoint_index}")
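
# With the settings above (`save_steps=1`, `max_steps=5`), this returns
# `/models/checkpoint-5` once training completes.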


# We provide the code for setting up an OpenAI-compatible inference endpoint here. For more details on serving models with vLLM, check out [this example](https://modal.com/docs/examples/vllm_inference#deploy-the-server).

vllm_image = (
    modal.Image.debian_slim(python_version="3.12")
    .uv_pip_install(
        "vllm==0.9.1",
        "flashinfer-python==0.2.6.post1",
        extra_index_url="https://download.pytorch.org/whl/cu128",
        extra_options="--index-strategy unsafe-best-match",
    )
    .env({"VLLM_USE_V1": "1"})
)

vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)


@app.function(
    image=vllm_image,
    gpu="H100",
    scaledown_window=15 * 60,  # How long should we stay up with no requests?
    timeout=10 * 60,  # How long should we wait for container start?
    volumes={"/root/.cache/vllm": vllm_cache_vol, MODELS_DIR: checkpoints_volume},
)
@modal.concurrent(
    max_inputs=32
)  # How many requests can one replica handle? tune carefully!
@modal.web_server(port=VLLM_PORT, startup_timeout=10 * 60)
def serve():
    latest_checkpoint_file_path = get_latest_checkpoint_file_path()

    cmd = [
        "vllm",
        "serve",
        "--uvicorn-log-level=info",
        latest_checkpoint_file_path,
        "--host",
        "0.0.0.0",
        "--port",
        str(VLLM_PORT),
    ]
    subprocess.Popen(" ".join(cmd), shell=True)


# You can then deploy the server using `modal deploy grpo_trl.py`, which gives you a custom URL. Query it using the following curl command:

# ```bash
# curl -X POST <url>/v1/chat/completions \
#   -H 'Content-Type: application/json' \
#   -d '{
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant for solving math problems."},
#       {"role": "user", "content": "James had 4 apples. Mary gave him 2 and he ate 1. How many does he have left?"}
#     ],
#     "temperature": 0.7
#   }'
# ```

# or in the [following ways](https://modal.com/docs/examples/vllm_inference#interact-with-the-server).
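
# For example, with the OpenAI Python client (a sketch; replace `<url>` with your deployment's URL;
# the served model name is the checkpoint path, discoverable via `client.models.list()`):

# ```python
# from openai import OpenAI
#
# client = OpenAI(base_url="<url>/v1", api_key="not-needed")
# model_id = client.models.list().data[0].id
# response = client.chat.completions.create(
#     model=model_id,
#     messages=[
#         {"role": "user", "content": "Write a Python function that checks whether a string is a palindrome."}
#     ],
#     temperature=0.7,
# )
# print(response.choices[0].message.content)
# ```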
